{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from exp.nb_03 import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DataBunch/Learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=4799)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train,y_train,x_valid,y_valid = get_data()\n",
    "train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n",
    "nh,bs = 50,64\n",
    "c = y_train.max().item()+1\n",
    "loss_func = F.cross_entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Factor out the connected pieces of info out of the fit() argument list\n",
    "\n",
    "`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`\n",
    "\n",
    "Let's replace it with something that looks like this:\n",
    "\n",
    "`fit(1, learn)`\n",
    "\n",
    "This will allow us to tweak what's happening inside the training loop in other places of the code because the `Learner` object will be mutable, so changing any of its attribute elsewhere will be seen in our training loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=5363)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class DataBunch():\n",
    "    def __init__(self, train_dl, valid_dl, c=None):\n",
    "        self.train_dl,self.valid_dl,self.c = train_dl,valid_dl,c\n",
    "        \n",
    "    @property\n",
    "    def train_ds(self): return self.train_dl.dataset\n",
    "        \n",
    "    @property\n",
    "    def valid_ds(self): return self.valid_dl.dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def get_model(data, lr=0.5, nh=50):\n",
    "    m = data.train_ds.x.shape[1]\n",
    "    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,data.c))\n",
    "    return model, optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "class Learner():\n",
    "    def __init__(self, model, opt, loss_func, data):\n",
    "        self.model,self.opt,self.loss_func,self.data = model,opt,loss_func,data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(*get_model(data), loss_func, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(epochs, learn):\n",
    "    for epoch in range(epochs):\n",
    "        learn.model.train()\n",
    "        for xb,yb in learn.data.train_dl:\n",
    "            loss = learn.loss_func(learn.model(xb), yb)\n",
    "            loss.backward()\n",
    "            learn.opt.step()\n",
    "            learn.opt.zero_grad()\n",
    "\n",
    "        learn.model.eval()\n",
    "        with torch.no_grad():\n",
    "            tot_loss,tot_acc = 0.,0.\n",
    "            for xb,yb in learn.data.valid_dl:\n",
    "                pred = learn.model(xb)\n",
    "                tot_loss += learn.loss_func(pred, yb)\n",
    "                tot_acc  += accuracy (pred,yb)\n",
    "        nv = len(learn.data.valid_dl)\n",
    "        print(epoch, tot_loss/nv, tot_acc/nv)\n",
    "    return tot_loss/nv, tot_acc/nv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor(0.2194) tensor(0.9372)\n"
     ]
    }
   ],
   "source": [
    "loss,acc = fit(1, learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CallbackHandler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This was our training loop (without validation) from the previous notebook, with the inner loop contents factored out:\n",
    "\n",
    "```python\n",
    "def one_batch(xb,yb):\n",
    "    pred = model(xb)\n",
    "    loss = loss_func(pred, yb)\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "    opt.zero_grad()\n",
    "    \n",
    "def fit():\n",
    "    for epoch in range(epochs):\n",
    "        for b in train_dl: one_batch(*b)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add callbacks so we can remove complexity from loop, and make it flexible:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=5628)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def one_batch(xb, yb, cb):\n",
    "    if not cb.begin_batch(xb,yb): return\n",
    "    loss = cb.learn.loss_func(cb.learn.model(xb), yb)\n",
    "    if not cb.after_loss(loss): return\n",
    "    loss.backward()\n",
    "    if cb.after_backward(): cb.learn.opt.step()\n",
    "    if cb.after_step(): cb.learn.opt.zero_grad()\n",
    "\n",
    "def all_batches(dl, cb):\n",
    "    for xb,yb in dl:\n",
    "        one_batch(xb, yb, cb)\n",
    "        if cb.do_stop(): return\n",
    "\n",
    "def fit(epochs, learn, cb):\n",
    "    if not cb.begin_fit(learn): return\n",
    "    for epoch in range(epochs):\n",
    "        if not cb.begin_epoch(epoch): continue\n",
    "        all_batches(learn.data.train_dl, cb)\n",
    "        \n",
    "        if cb.begin_validate():\n",
    "            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)\n",
    "        if cb.do_stop() or not cb.after_epoch(): break\n",
    "    cb.after_fit()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Callback():\n",
    "    def begin_fit(self, learn):\n",
    "        self.learn = learn\n",
    "        return True\n",
    "    def after_fit(self): return True\n",
    "    def begin_epoch(self, epoch):\n",
    "        self.epoch=epoch\n",
    "        return True\n",
    "    def begin_validate(self): return True\n",
    "    def after_epoch(self): return True\n",
    "    def begin_batch(self, xb, yb):\n",
    "        self.xb,self.yb = xb,yb\n",
    "        return True\n",
    "    def after_loss(self, loss):\n",
    "        self.loss = loss\n",
    "        return True\n",
    "    def after_backward(self): return True\n",
    "    def after_step(self): return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CallbackHandler():\n",
    "    def __init__(self,cbs=None):\n",
    "        self.cbs = cbs if cbs else []\n",
    "\n",
    "    def begin_fit(self, learn):\n",
    "        self.learn,self.in_train = learn,True\n",
    "        learn.stop = False\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.begin_fit(learn)\n",
    "        return res\n",
    "\n",
    "    def after_fit(self):\n",
    "        res = not self.in_train\n",
    "        for cb in self.cbs: res = res and cb.after_fit()\n",
    "        return res\n",
    "    \n",
    "    def begin_epoch(self, epoch):\n",
    "        self.learn.model.train()\n",
    "        self.in_train=True\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.begin_epoch(epoch)\n",
    "        return res\n",
    "\n",
    "    def begin_validate(self):\n",
    "        self.learn.model.eval()\n",
    "        self.in_train=False\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.begin_validate()\n",
    "        return res\n",
    "\n",
    "    def after_epoch(self):\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.after_epoch()\n",
    "        return res\n",
    "    \n",
    "    def begin_batch(self, xb, yb):\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)\n",
    "        return res\n",
    "\n",
    "    def after_loss(self, loss):\n",
    "        res = self.in_train\n",
    "        for cb in self.cbs: res = res and cb.after_loss(loss)\n",
    "        return res\n",
    "\n",
    "    def after_backward(self):\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.after_backward()\n",
    "        return res\n",
    "\n",
    "    def after_step(self):\n",
    "        res = True\n",
    "        for cb in self.cbs: res = res and cb.after_step()\n",
    "        return res\n",
    "    \n",
    "    def do_stop(self):\n",
    "        try:     return self.learn.stop\n",
    "        finally: self.learn.stop = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestCallback(Callback):\n",
    "    def begin_fit(self,learn):\n",
    "        super().begin_fit(learn)\n",
    "        self.n_iters = 0\n",
    "        return True\n",
    "        \n",
    "    def after_step(self):\n",
    "        self.n_iters += 1\n",
    "        print(self.n_iters)\n",
    "        if self.n_iters>=10: self.learn.stop = True\n",
    "        return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n"
     ]
    }
   ],
   "source": [
    "fit(1, learn, cb=CallbackHandler([TestCallback()]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint they should all be in the same class!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Runner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=5811)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import re\n",
    "\n",
    "_camel_re1 = re.compile('(.)([A-Z][a-z]+)')\n",
    "_camel_re2 = re.compile('([a-z0-9])([A-Z])')\n",
    "def camel2snake(name):\n",
    "    s1 = re.sub(_camel_re1, r'\\1_\\2', name)\n",
    "    return re.sub(_camel_re2, r'\\1_\\2', s1).lower()\n",
    "\n",
    "class Callback():\n",
    "    _order=0\n",
    "    def set_runner(self, run): self.run=run\n",
    "    def __getattr__(self, k): return getattr(self.run, k)\n",
    "    @property\n",
    "    def name(self):\n",
    "        name = re.sub(r'Callback$', '', self.__class__.__name__)\n",
    "        return camel2snake(name or 'callback')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This first callback is reponsible to switch the model back and forth in training or validation mode, as well as maintaining a count of the iterations, or the percentage of iterations ellapsed in the epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class TrainEvalCallback(Callback):\n",
    "    def begin_fit(self):\n",
    "        self.run.n_epochs=0.\n",
    "        self.run.n_iter=0\n",
    "    \n",
    "    def after_batch(self):\n",
    "        if not self.in_train: return\n",
    "        self.run.n_epochs += 1./self.iters\n",
    "        self.run.n_iter   += 1\n",
    "        \n",
    "    def begin_epoch(self):\n",
    "        self.run.n_epochs=self.epoch\n",
    "        self.model.train()\n",
    "        self.run.in_train=True\n",
    "\n",
    "    def begin_validate(self):\n",
    "        self.model.eval()\n",
    "        self.run.in_train=False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll also re-create our TestCallback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestCallback(Callback):\n",
    "    def after_step(self):\n",
    "        if self.train_eval.n_iters>=10: return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'train_eval_callback'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cbname = 'TrainEvalCallback'\n",
    "camel2snake(cbname)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'train_eval'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TrainEvalCallback().name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from typing import *\n",
    "\n",
    "def listify(o):\n",
    "    if o is None: return []\n",
    "    if isinstance(o, list): return o\n",
    "    if isinstance(o, str): return [o]\n",
    "    if isinstance(o, Iterable): return list(o)\n",
    "    return [o]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class Runner():\n",
    "    def __init__(self, cbs=None, cb_funcs=None):\n",
    "        cbs = listify(cbs)\n",
    "        for cbf in listify(cb_funcs):\n",
    "            cb = cbf()\n",
    "            setattr(self, cb.name, cb)\n",
    "            cbs.append(cb)\n",
    "        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs\n",
    "\n",
    "    @property\n",
    "    def opt(self):       return self.learn.opt\n",
    "    @property\n",
    "    def model(self):     return self.learn.model\n",
    "    @property\n",
    "    def loss_func(self): return self.learn.loss_func\n",
    "    @property\n",
    "    def data(self):      return self.learn.data\n",
    "\n",
    "    def one_batch(self, xb, yb):\n",
    "        self.xb,self.yb = xb,yb\n",
    "        if self('begin_batch'): return\n",
    "        self.pred = self.model(self.xb)\n",
    "        if self('after_pred'): return\n",
    "        self.loss = self.loss_func(self.pred, self.yb)\n",
    "        if self('after_loss') or not self.in_train: return\n",
    "        self.loss.backward()\n",
    "        if self('after_backward'): return\n",
    "        self.opt.step()\n",
    "        if self('after_step'): return\n",
    "        self.opt.zero_grad()\n",
    "\n",
    "    def all_batches(self, dl):\n",
    "        self.iters = len(dl)\n",
    "        for xb,yb in dl:\n",
    "            if self.stop: break\n",
    "            self.one_batch(xb, yb)\n",
    "            self('after_batch')\n",
    "        self.stop=False\n",
    "\n",
    "    def fit(self, epochs, learn):\n",
    "        self.epochs,self.learn = epochs,learn\n",
    "\n",
    "        try:\n",
    "            for cb in self.cbs: cb.set_runner(self)\n",
    "            if self('begin_fit'): return\n",
    "            for epoch in range(epochs):\n",
    "                self.epoch = epoch\n",
    "                if not self('begin_epoch'): self.all_batches(self.data.train_dl)\n",
    "\n",
    "                with torch.no_grad(): \n",
    "                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)\n",
    "                if self('after_epoch'): break\n",
    "            \n",
    "        finally:\n",
    "            self('after_fit')\n",
    "            self.learn = None\n",
    "\n",
    "    def __call__(self, cb_name):\n",
    "        for cb in sorted(self.cbs, key=lambda x: x._order):\n",
    "            f = getattr(cb, cb_name, None)\n",
    "            if f and f(): return True\n",
    "        return False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Third callback: how to compute metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class AvgStats():\n",
    "    def __init__(self, metrics, in_train): self.metrics,self.in_train = listify(metrics),in_train\n",
    "    \n",
    "    def reset(self):\n",
    "        self.tot_loss,self.count = 0.,0\n",
    "        self.tot_mets = [0.] * len(self.metrics)\n",
    "        \n",
    "    @property\n",
    "    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets\n",
    "    @property\n",
    "    def avg_stats(self): return [o/self.count for o in self.all_stats]\n",
    "    \n",
    "    def __repr__(self):\n",
    "        if not self.count: return \"\"\n",
    "        return f\"{'train' if self.in_train else 'valid'}: {self.avg_stats}\"\n",
    "\n",
    "    def accumulate(self, run):\n",
    "        bn = run.xb.shape[0]\n",
    "        self.tot_loss += run.loss * bn\n",
    "        self.count += bn\n",
    "        for i,m in enumerate(self.metrics):\n",
    "            self.tot_mets[i] += m(run.pred, run.yb) * bn\n",
    "\n",
    "class AvgStatsCallback(Callback):\n",
    "    def __init__(self, metrics):\n",
    "        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)\n",
    "        \n",
    "    def begin_epoch(self):\n",
    "        self.train_stats.reset()\n",
    "        self.valid_stats.reset()\n",
    "        \n",
    "    def after_loss(self):\n",
    "        stats = self.train_stats if self.in_train else self.valid_stats\n",
    "        with torch.no_grad(): stats.accumulate(self.run)\n",
    "    \n",
    "    def after_epoch(self):\n",
    "        print(self.train_stats)\n",
    "        print(self.valid_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = Learner(*get_model(data), loss_func, data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stats = AvgStatsCallback([accuracy])\n",
    "run = Runner(cbs=stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train: [0.31116248046875, tensor(0.9052)]\n",
      "valid: [0.141587158203125, tensor(0.9587)]\n",
      "train: [0.144049990234375, tensor(0.9572)]\n",
      "valid: [0.2177901611328125, tensor(0.9354)]\n"
     ]
    }
   ],
   "source": [
    "run.fit(2, learn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.2177901611328125, tensor(0.9354))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss,acc = stats.valid_stats.avg_stats\n",
    "assert acc>0.9\n",
    "loss,acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_cbf = partial(AvgStatsCallback,accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run = Runner(cb_funcs=acc_cbf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train: [0.109901865234375, tensor(0.9663)]\n",
      "valid: [0.12098992919921875, tensor(0.9649)]\n"
     ]
    }
   ],
   "source": [
    "run.fit(1, learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using Jupyter means we can get tab-completion even for dynamic code like this! :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.12098992919921875, tensor(0.9649)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "run.avg_stats.valid_stats.avg_stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 04_callbacks.ipynb to nb_04.py\r\n"
     ]
    }
   ],
   "source": [
    "!python notebook2script.py 04_callbacks.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
